{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['/home/tiendat/transformer-entropy-ids/notebooks', '/home/tiendat/miniconda3/envs/torchtf/lib/python39.zip', '/home/tiendat/miniconda3/envs/torchtf/lib/python3.9', '/home/tiendat/miniconda3/envs/torchtf/lib/python3.9/lib-dynload', '', '/home/tiendat/miniconda3/envs/torchtf/lib/python3.9/site-packages', '../']\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "\n",
    "# Make modules located in the parent directory importable from this notebook.\n",
    "# The membership check keeps the cell idempotent: re-running it no longer\n",
    "# appends a duplicate '../' entry to sys.path on every execution.\n",
    "if '../' not in sys.path:\n",
    "    sys.path.append('../')\n",
    "\n",
    "print(sys.path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import math\n",
    "import argparse\n",
    "from dataPreprocess import DatasetPreprocess\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score, classification_report, confusion_matrix\n",
    "import os\n",
    "import torch.optim as optim\n",
    "import time\n",
    "import warnings\n",
    "import torch.nn.functional as F\n",
    "import time\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRAIN SIZE: 84688  TEST SIZE: 21188  SIZE: 105876  TRAIN RATIO: 80 %\n",
      "finish load data\n"
     ]
    }
   ],
   "source": [
    "# Seed the torch RNGs (CPU, current CUDA device, all CUDA devices).\n",
    "seed = 1\n",
    "for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):\n",
    "    seed_fn(seed)\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Dataset / preprocessing settings\n",
    "# ---------------------------------------------------------------\n",
    "root_dir = \"../road/predict_fab_multi/TFRecord_w15_s15/1/\"\n",
    "window_size = 15\n",
    "d_model = 12\n",
    "max_time_position = 10000\n",
    "gran = 1e-7\n",
    "log_e = 2\n",
    "batch_size = 10\n",
    "\n",
    "# window_size is supplied for both the 2nd and the 3rd positional argument,\n",
    "# exactly as in the two original constructor calls.\n",
    "dataset_args = (root_dir, window_size, window_size, d_model, max_time_position, gran, log_e)\n",
    "train_dataset = DatasetPreprocess(*dataset_args, is_train=True)\n",
    "test_dataset = DatasetPreprocess(*dataset_args, is_train=False)\n",
    "\n",
    "n_train, n_test = len(train_dataset), len(test_dataset)\n",
    "n_total = n_train + n_test\n",
    "print(\"TRAIN SIZE:\", n_train, \" TEST SIZE:\", n_test, \" SIZE:\", n_total, \" TRAIN RATIO:\", round(n_train / n_total * 100), \"%\")\n",
    "\n",
    "# One loader per split; only the training loader shuffles.\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)\n",
    "print('finish load data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tiendat/miniconda3/envs/torchtf/lib/python3.9/site-packages/torch/nn/modules/transformer.py:282: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
      "  warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n",
      "/home/tiendat/miniconda3/envs/torchtf/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "class Config:\n",
    "    def __init__(self, args):\n",
    "        self.model_name = 'IDS-Transformer_' + args.type\n",
    "        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "        \n",
    "        self.payload_size = 8 \n",
    "            \n",
    "        self.mode = args.mode\n",
    "        self.dout_mess = 10\n",
    "        self.d_model = self.dout_mess\n",
    "        self.nhead = 10 \n",
    "        \n",
    "        self.tse = args.tse\n",
    "        self.pad_size = args.window_size # 15\n",
    "        self.window_size =  args.window_size # 15\n",
    "        self.max_time_position = 10000\n",
    "        self.num_layers = 6\n",
    "        self.gran = 1e-6\n",
    "        self.log_e = 2\n",
    "        \n",
    "        if args.type == 'chd':\n",
    "            self.classes_num = 5\n",
    "        elif args.type == 'road_mas':\n",
    "            self.classes_num = 6\n",
    "        else: # road_fab\n",
    "            self.classes_num = 7 \n",
    "        \n",
    "            \n",
    "        self.batch_size = args.batch_size\n",
    "        self.epoch_num = args.epoch\n",
    "        self.lr = args.lr\n",
    "        self.root_dir = args.indir\n",
    "        \n",
    "        # self.root_dir = './data/Processed/TFRecord_w29_s29/2/'\n",
    "        # self.root_dir = './road/predict_fab_multi/TFRecord_w15_s15/1/'\n",
    "        # self.root_dir = './road/predict_mas/TFRecord_w15_s15/4/'\n",
    "        self.model_save_path = '../model/' + self.model_name + '/'\n",
    "        if not os.path.exists(self.model_save_path):\n",
    "            os.mkdir(self.model_save_path)\n",
    "        self.result_file = '/home/tiendat/transformer-entropy-ids/result/trans8_performance.txt'\n",
    "\n",
    "        self.isload_model = False  \n",
    "        self.start_epoch = 24  # The epoch of the loaded model\n",
    "        self.model_path = 'model/' + self.model_name + '/' + self.model_name + '_model_' + str(self.start_epoch) + '.pth' \n",
    "\n",
    "class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):\n",
    "    def __init__(self, optimizer, warmup, max_iters):\n",
    "        self.warmup = warmup\n",
    "        self.max_num_iters = max_iters\n",
    "        super().__init__(optimizer)\n",
    "\n",
    "    def get_lr(self):\n",
    "        lr_factor = self.get_lr_factor(epoch=self.last_epoch)\n",
    "        return [base_lr * lr_factor for base_lr in self.base_lrs]\n",
    "\n",
    "    def get_lr_factor(self, epoch):\n",
    "        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))\n",
    "        if epoch <= self.warmup:\n",
    "            lr_factor *= epoch * 1.0 / self.warmup\n",
    "        return lr_factor\n",
    "    \n",
    "class Autoencoder1D(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(Autoencoder1D, self).__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(input_dim, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, 16),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(16, output_dim)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.encoder(x)\n",
    "        return F.relu(x)\n",
    "\n",
    "class Time_Positional_Encoding(nn.Module):\n",
    "    def __init__(self, embed, max_time_position, device):\n",
    "        super(Time_Positional_Encoding, self).__init__()\n",
    "        self.device = device\n",
    "\n",
    "    def forward(self, x, time_position):\n",
    "        out = x.permute(1, 0, 2)\n",
    "        out = out + nn.Parameter(time_position, requires_grad=False).to(self.device)\n",
    "        out = out.permute(1, 0, 2)\n",
    "        return out\n",
    "    \n",
    "class TransformerPredictor(nn.Module):\n",
    "    \"\"\"Window classifier: per-message encoder -> time encoding -> transformer encoder -> linear head.\n",
    "\n",
    "    Args:\n",
    "        config: Object providing ``dout_mess``, ``pad_size``, ``d_model``,\n",
    "            ``max_time_position``, ``nhead``, ``num_layers``, ``classes_num``\n",
    "            and ``device``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "\n",
    "        # 12 is the width of header and sl_sum concatenated in forward().\n",
    "        self.ae = Autoencoder1D(12, config.dout_mess).to(config.device)\n",
    "        self.pad_size = config.pad_size\n",
    "\n",
    "        self.position_embedding = Time_Positional_Encoding(config.d_model, config.max_time_position, config.device).to(config.device)\n",
    "\n",
    "        # encoder_layer stays registered as an attribute so existing state_dict\n",
    "        # keys are unchanged. NOTE(review): nn.TransformerEncoder is understood to\n",
    "        # work on its own copies of this layer - confirm before removing it.\n",
    "        self.encoder_layer = nn.TransformerEncoderLayer(d_model=config.d_model, nhead=config.nhead).to(config.device)\n",
    "        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=config.num_layers).to(config.device)\n",
    "        self.fc = nn.Linear(config.d_model, config.classes_num).to(config.device)\n",
    "\n",
    "    def forward(self, header, sl_sum, mask, time_position):\n",
    "        \"\"\"Return class logits for a batch of message windows.\n",
    "\n",
    "        Args:\n",
    "            header: Tensor concatenated with ``sl_sum`` along the last axis.\n",
    "            sl_sum: Tensor concatenated with ``header`` along the last axis.\n",
    "            mask: Passed unchanged as ``src_key_padding_mask``.\n",
    "            time_position: Encoding added by ``position_embedding``.\n",
    "\n",
    "        Returns:\n",
    "            Output of the final linear layer applied to the encoder output\n",
    "            summed over the sequence axis.\n",
    "        \"\"\"\n",
    "        x = torch.concat((header, sl_sum), dim=-1)\n",
    "\n",
    "        # The encoder's linear layers act on the last axis, so the first pad_size\n",
    "        # positions are encoded in one batched call instead of one call plus one\n",
    "        # concat per position. This also removes the dependency on the notebook\n",
    "        # global `config` and on a hard-coded feature width of 10.\n",
    "        # A window shorter than pad_size is now encoded as-is rather than\n",
    "        # raising IndexError here.\n",
    "        x = self.ae(x[:, :self.pad_size, :])\n",
    "        # (batch, seq, feat) -> (seq, batch, feat), the layout produced before.\n",
    "        x = x.permute(1, 0, 2)\n",
    "\n",
    "        out = self.position_embedding(x, time_position)\n",
    "        out = self.transformer_encoder(out, src_key_padding_mask=mask)\n",
    "        out = out.permute(1, 0, 2)\n",
    "        out = torch.sum(out, 1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n",
    "# Run settings, gathered in one mapping and exposed through an argparse\n",
    "# Namespace so Config can read them as attributes.\n",
    "run_settings = {\n",
    "    'indir': '',\n",
    "    'window_size': 15,\n",
    "    'type': 'road_mas',\n",
    "    'epoch': 200,\n",
    "    'batch_size': 32,\n",
    "    'lr': 0.0001,\n",
    "    'mode': 'cb',\n",
    "    'tse': True,\n",
    "}\n",
    "args = argparse.Namespace(**run_settings)\n",
    "\n",
    "config = Config(args)\n",
    "model = TransformerPredictor(config)\n",
    "start_epoch = -1\n",
    "\n",
    "loss_func = nn.CrossEntropyLoss().to(config.device)\n",
    "opt = optim.Adam(model.parameters(), lr=config.lr)\n",
    "\n",
    "# max_iters = number of epochs x number of batches in train_loader.\n",
    "# NOTE(review): this presumably means one scheduler step per batch - confirm\n",
    "# against the training loop.\n",
    "total_steps = config.epoch_num * len(train_loader)\n",
    "lr_scheduler = CosineWarmupScheduler(opt, warmup=50, max_iters=total_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inference time: 0.005384206771850586 seconds\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "\n",
    "# Random stand-in batch; the four tensors are passed positionally to\n",
    "# model.forward(header, sl_sum, mask, time_position).\n",
    "inputs = [torch.randn(10, 15, 4), torch.randn(10, 15, 8), torch.randn(10, 15), torch.randn(10, 15, 10)]\n",
    "\n",
    "# Move the model and inputs to the same device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = model.to(device)\n",
    "inputs = [i.to(device) for i in inputs]\n",
    "\n",
    "# Make sure the model is in evaluation mode\n",
    "model.eval()\n",
    "\n",
    "# Warm up (for more accurate timing)\n",
    "with torch.no_grad():\n",
    "    _ = model(*inputs)\n",
    "\n",
    "# CUDA work is queued asynchronously, so without a synchronize() before each\n",
    "# clock reading the timer would stop before the GPU has finished the pass.\n",
    "if device.type == 'cuda':\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "# Time the inference (perf_counter is a monotonic, high-resolution clock)\n",
    "start_time = time.perf_counter()\n",
    "with torch.no_grad():\n",
    "    _ = model(*inputs)\n",
    "if device.type == 'cuda':\n",
    "    torch.cuda.synchronize()\n",
    "end_time = time.perf_counter()\n",
    "\n",
    "print(f'Inference time: {end_time - start_time} seconds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.\n",
      "[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.\n",
      "MACs (G):  37.095\n",
      "Params (M):  0.259528\n"
     ]
    }
   ],
   "source": [
    "from thop import profile\n",
    "\n",
    "# Random stand-in batch, named after the arguments of model.forward().\n",
    "header = torch.rand((10, 15, 4)).to(config.device)\n",
    "sl_sum = torch.rand((10, 15, 8)).to(config.device)\n",
    "mask = torch.rand((10, 15,)).to(config.device)\n",
    "time_position = torch.rand((10, 15, 10)).to(config.device)\n",
    "\n",
    "macs, params = profile(model, inputs=[header, sl_sum, mask, time_position])\n",
    "# Both counts are divided by 1000**2, i.e. reported in millions. The MACs line\n",
    "# was previously labelled (G) although the divisor is 1e6, not 1e9.\n",
    "print('MACs (M): ', macs/1000**2)\n",
    "print('Params (M): ', params/1000**2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "piptorchtf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
